"""
callbacks for the neurite project

If you use this code, please cite 

Dalca AV, Guttag J, Sabuncu MR
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 
CVPR 2018

Contact: adalca [at] csail [dot] mit [dot] edu
License: GPLv3
"""

# internal python imports
import sys
import time
import warnings

# third party imports
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# local (our) imports
import neurite as ne
from pystrum.pytools import timer


class ModelWeightCheck(keras.callbacks.Callback):
    """
        check model weights for nan and infinite entries
    """

    def __init__(self,
                 weight_diff=False,
                 at_batch_end=False,                 
                 at_epoch_end=True):
        """
        Params:
            at_batch_end: None or number indicate when to execute
                (i.e. at_batch_end = 10 means execute every 10 batches)
            at_epoch_end: logical, whether to execute at epoch end
        """
        super(ModelWeightCheck, self).__init__()
        self.at_batch_end = at_batch_end
        self.at_epoch_end = at_epoch_end
        self.current_epoch = 0
        self.weight_diff = weight_diff
        self.wts = None

    def on_batch_end(self, batch, logs=None):
        if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0:
            self.on_model_check(self.current_epoch, batch + 1, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        if self.at_epoch_end:
            self.on_model_check(epoch, 0, logs=logs)
        self.current_epoch = epoch

    def on_model_check(self, epoch, iter, logs=None):
        for layer in self.model.layers:
            for wt in layer.get_weights():
                assert ~np.any(np.isnan(wt)), 'Found nan weights in model layer %s' % layer.name
                assert np.all(np.isfinite(wt)), 'Found infinite weights in model layer %s' % layer.name

        # compute max change
        if self.weight_diff:
            wts = self.model.get_weights()
            diff = -np.inf

            if self.wts is not None:
                for wi, w in enumerate(wts):
                    if len(w) > 0:
                        for si, sw in enumerate(w):
                            diff = np.maximum(diff, np.max(np.abs(sw - self.wts[wi][si])))
                            
            self.wts = wts
            logs['max_diff'] = diff
            # print("max diff", diff)


class CheckLossTrend(keras.callbacks.Callback):
    """
        check model weights for nan and infinite entries
    """

    def __init__(self,
                 at_batch_end=True,
                 at_epoch_end=False,
                 nb_std_err=2,
                 loss_window=10):
        """
        Params:
            at_batch_end: None or number indicate when to execute
                (i.e. at_batch_end = 10 means execute every 10 batches)
            at_epoch_end: logical, whether to execute at epoch end
        """
        super(CheckLossTrend, self).__init__()
        self.at_batch_end = at_batch_end
        self.at_epoch_end = at_epoch_end
        self.current_epoch = 0
        self.loss_window = loss_window
        self.nb_std_err = nb_std_err
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0:
            self.on_model_check(self.current_epoch, batch + 1, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        if self.at_epoch_end:
            self.on_model_check(epoch, 0, logs=logs)
        self.current_epoch = epoch

    def on_model_check(self, epoch, iter, logs=None):
        if len(self.losses) < self.loss_window:
            self.losses = [*self.losses, logs['loss']]
        else:
            losses_mean = np.mean(self.losses)
            losses_std = np.std(self.losses)
            this_loss = logs['loss']

            if (this_loss) > (losses_mean + self.nb_std_err * losses_std):
                print(logs)
                err = "Found loss %f, which is much higher than %f + %f " % (this_loss, losses_mean, losses_std)
                # raise ValueError(err)
                print(err, file=sys.stderr)
            
            if (this_loss - losses_mean) > (losses_mean * 100):
                err = "Found loss %f, which is much higher than %f * 100 " % (this_loss, losses_mean)
                raise ValueError(err)

            # cut the first loss and stack athe latest loss.
            self.losses = [*self.losses[1:], logs['loss']]


class PlotTestSlices(keras.callbacks.Callback):
    '''
    plot slices of a test subject from several directions
    '''

    def __init__(self,
                 savefilepath,
                 generator,
                 vol_size,
                 run,   # object with fields: patch_size, patch_stride, grid_size
                 data,  # object with fields:
                 at_batch_end=None,     # None or number indicate when to execute (i.e. at_batch_end = 10 means execute every 10 batches)
                 at_epoch_end=True,     # logical, whether to execute at epoch end
                 verbose=False,
                 period=1,
                 prior=None):
        """
        Parameteres:
            savefilepath,
            generator,
            vol_size,
            run: object with fields: patch_size, patch_stride, grid_size
            data: object with fields:
            at_batch_end=None: None or number indicate when to execute (i.e. at_batch_end = 10 means execute every 10 batches)
            at_epoch_end=True: logical, whether to execute at epoch end
            verbose=False:
            period=1
            prior=None
        """

        super().__init__()

        # save some parameters
        self.savefilepath = savefilepath
        self.generator = generator
        self.vol_size = vol_size

        self.run = run
        self.data = data

        self.at_batch_end = at_batch_end
        self.at_epoch_end = at_epoch_end
        self.current_epoch = 0
        self.period = period

        self.verbose = verbose

        # prepare prior
        self.prior = None
        if prior is not None:
            data = np.load(prior)
            loc_vol = data['prior']
            self.prior = np.expand_dims(loc_vol, axis=0) # reshape for model

    def on_batch_end(self, batch, logs={}):
        if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0:
            self.on_plot_save(self.current_epoch, batch + 1, logs=logs)

    def on_epoch_end(self, epoch, logs={}):
        if self.at_epoch_end and np.mod(epoch + 1, self.period) == 0:
            self.on_plot_save(epoch, 0, logs=logs)
        self.current_epoch = epoch

    def on_plot_save(self, epoch, iter, logs={}):
        # import neuron sandbox
        # has to be here, can't be at the top, due to cyclical imports (??)
        # TODO: should just pass the function to compute the figures given the model and generator

        with timer.Timer('plot callback', self.verbose):
            if len(self.run.grid_size) == 3:
                collapse_2d = [0, 1, 2]
            else:
                collapse_2d = [2]

            # TODO: show_example_prediction_result is actually in neuron_sandbox for now
            exampl = show_example_prediction_result(self.model,
                                                                self.generator,
                                                                self.run,
                                                                self.data,
                                                                test_batch_size=1,
                                                                test_model_names=None,
                                                                test_grid_size=self.run.grid_size,
                                                                ccmap=None,
                                                                collapse_2d=collapse_2d,
                                                                slice_nr=None,
                                                                plt_width=17,
                                                                verbose=self.verbose)

            # save, then close
            figs = exampl[1:]
            for idx, fig in enumerate(figs):
                dirn = "dirn_%d" % idx
                slice_nr = 0
                filename = self.savefilepath.format(epoch=epoch, iter=iter, axis=dirn, slice_nr=slice_nr)
                fig.savefig(filename)
            plt.close()


class PredictMetrics(keras.callbacks.Callback):
    '''
    Compute metrics, like Dice, and save to CSV/log

    '''

    def __init__(self,
                 filepath,
                 metrics,
                 data_generator,
                 nb_samples,
                 nb_labels,
                 batch_size,
                 label_ids=None,
                 vol_params=None,
                 at_batch_end=None,
                 at_epoch_end=True,
                 period=1,
                 verbose=False):
        """
        Parameters:
            filepath: filepath with epoch and metric
            metrics: list of metrics (functions)
            data_generator: validation generator
            nb_samples: number of validation samples - volumes or batches
                depending on whether vol_params is passed or not
            nb_labels: number of labels
            batch_size:
            label_ids=None:
            vol_params=None:
            at_batch_end=None: None or number indicate when to execute
                (i.e. at_batch_end = 10 means execute every 10 batches)
            at_epoch_end=True: logical, whether to execute at epoch end
            verbose=False
        """

        # pass in the parameters to object variables
        self.metrics = metrics
        self.data_generator = data_generator
        self.nb_samples = nb_samples
        self.filepath = filepath
        self.nb_labels = nb_labels
        if label_ids is None:
            self.label_ids = list(range(nb_labels))
        else:
            self.label_ids = label_ids
        self.vol_params = vol_params

        self.current_epoch = 1
        self.at_batch_end = at_batch_end
        self.at_epoch_end = at_epoch_end
        self.batch_size = batch_size
        self.period = period

        self.verbose = verbose

    def on_batch_end(self, batch, logs={}):
        if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0:
            self.on_metric_call(self.current_epoch, batch + 1, logs=logs)

    def on_epoch_end(self, epoch, logs={}):
        if self.at_epoch_end and np.mod(epoch + 1, self.period) == 0:
            self.on_metric_call(epoch, 0, logs=logs)
        self.current_epoch = epoch

    def on_metric_call(self, epoch, iter, logs={}):
        """ compute metrics on several predictions """
        with timer.Timer('predict metrics callback', self.verbose):

            # prepare metric
            met = np.zeros((self.nb_samples, self.nb_labels, len(self.metrics)))

            # generate predictions
            # the idea is to predict either a full volume or just a slice,
            # depending on what we need
            gen = _generate_predictions(self.model,
                                        self.data_generator,
                                        self.batch_size,
                                        self.nb_samples,
                                        self.vol_params)
            batch_idx = 0
            for (vol_true, vol_pred) in gen:
                for idx, metric in enumerate(self.metrics):
                    met[batch_idx, :, idx] = metric(vol_true, vol_pred)
                batch_idx += 1

            # write metric to csv file
            if self.filepath is not None:
                for idx, metric in enumerate(self.metrics):
                    filen = self.filepath.format(epoch=epoch, iter=iter, metric=metric.__name__)
                    np.savetxt(filen, met[:, :, idx], fmt='%f', delimiter=',')
            else:
                meanmet = np.nanmean(met, axis=0)
                for midx, metric in enumerate(self.metrics):
                    for idx in range(self.nb_labels):
                        varname = '%s_label_%d' % (metric.__name__, self.label_ids[idx])
                        logs[varname] = meanmet[idx, midx]


class ModelCheckpoint(keras.callbacks.Callback):
    """
    A modification of keras' ModelCheckpoint, but allow for saving on_batch_end
    changes include:
    - optional at_batch_end, at_epoch_end arguments,
    - filename now must includes 'iter'

    Save the model after every epoch.
    `filepath` can contain named formatting options,
    which will be filled the value of `epoch` and
    keys in `logs` (passed in `on_epoch_end`).
    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
    then the model checkpoints will be saved with the epoch number and
    the validation loss in the filename.
    # Arguments
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`,
            the latest best model according to
            the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}.
            If `save_best_only=True`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        period: Interval (number of epochs) between checkpoints.
    """

    def __init__(self, filepath,
                 monitor='val_loss',
                 save_best_only=False,
                 save_weights_only=False,
                 at_batch_end=None,
                 at_epoch_end=True,
                 mode='auto', period=1,
                 verbose=False):
        """
        Parameters:
            ...
            at_batch_end=None: None or number indicate when to execute
                (i.e. at_batch_end = 10 means execute every 10 batches)
            at_epoch_end=True: logical, whether to execute at epoch end
        """
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.steps_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

        self.at_batch_end = at_batch_end
        self.at_epoch_end = at_epoch_end
        self.current_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        self.current_epoch = epoch

    def on_batch_end(self, batch, logs=None):
        if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0:
            print("Saving model at batch end!")
            self.on_model_save(self.current_epoch, batch + 1, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        if self.at_epoch_end:
            self.on_model_save(epoch, 0, logs=logs)
        self.current_epoch = epoch + 1

    def on_model_save(self, epoch, iter, logs=None):
        """ save the model to hdf5. Code mostly from keras core """

        with timer.Timer('model save callback', self.verbose):
            logs = logs or {}
            self.steps_since_last_save += 1
            if self.steps_since_last_save >= self.period:
                self.steps_since_last_save = 0
                filepath = self.filepath.format(epoch=epoch, iter=iter, **logs)
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    if current is None:
                        warnings.warn('Can save best model only with %s available, '
                                      'skipping.' % (self.monitor), RuntimeWarning)
                    else:
                        if self.monitor_op(current, self.best):
                            if self.verbose > 0:
                                print('Epoch %05d Iter%05d: %s improved from %0.5f to %0.5f,'
                                      ' saving model to %s'
                                      % (epoch, iter, self.monitor, self.best,
                                         current, filepath))
                            self.best = current
                            if self.save_weights_only:
                                self.model.save_weights(filepath, overwrite=True)
                            else:
                                self.model.save(filepath, overwrite=True)
                        else:
                            if self.verbose > 0:
                                print('Epoch %05d Iter%05d: %s did not improve' %
                                      (epoch, iter, self.monitor))
                else:
                    if self.verbose > 0:
                        print('Epoch %05d: saving model to %s' % (epoch, filepath))
                    if self.save_weights_only:
                        self.model.save_weights(filepath, overwrite=True)
                    else:
                        self.model.save(filepath, overwrite=True)


class ModelCheckpointParallel(keras.callbacks.Callback):
    """
    
    borrow from: https://github.com/rmkemker/main/blob/master/machine_learning/model_checkpoint_parallel.py
    
    Save the model after every epoch.
    `filepath` can contain named formatting options,
    which will be filled the value of `epoch` and
    keys in `logs` (passed in `on_epoch_end`).
    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
    then the model checkpoints will be saved with the epoch number and
    the validation loss in the filename.
    # Arguments
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`,
            the latest best model according to
            the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}.
            If `save_best_only=True`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        period: Interval (number of epochs) between checkpoints.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 at_batch_end=None,
                 at_epoch_end=True,
                 mode='auto', period=1):
        super(ModelCheckpointParallel, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpointParallel mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

        self.at_batch_end = at_batch_end
        self.at_epoch_end = at_epoch_end
        self.current_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        self.current_epoch = epoch

    def on_batch_end(self, batch, logs=None):
        if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0:
            print("Saving model at batch end!")
            self.on_model_save(self.current_epoch, batch + 1, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        if self.at_epoch_end:
            self.on_model_save(epoch, 0, logs=logs)
        self.current_epoch = epoch + 1

    def on_model_save(self, epoch, iter, logs=None):
        """ save the model to hdf5. Code mostly from keras core """

        with timer.Timer('model save callback', self.verbose):
            logs = logs or {}
            num_outputs = len(self.model.outputs)
            self.epochs_since_last_save += 1
            if self.epochs_since_last_save >= self.period:
                self.epochs_since_last_save = 0
                filepath = self.filepath.format(epoch=epoch, iter=iter, **logs)
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    if current is None:
                        warnings.warn('Can save best model only with %s available, '
                                    'skipping.' % (self.monitor), RuntimeWarning)
                    else:
                        if self.monitor_op(current, self.best):
                            if self.verbose > 0:
                                print('Epoch %05d: Iter%05d: %s improved from %0.5f to %0.5f,'
                                    ' saving model to %s'
                                    % (epoch, iter, self.monitor, self.best,
                                        current, filepath))
                            self.best = current
                            if self.save_weights_only:
                                self.model.layers[-(num_outputs+1)].save_weights(filepath, overwrite=True)
                            else:
                                self.model.layers[-(num_outputs+1)].save(filepath, overwrite=True)
                        else:
                            if self.verbose > 0:
                                print('Epoch %05d Iter%05d: %s did not improve' %
                                    (epoch, iter, self.monitor))
                else:
                    if self.verbose > 0:
                        print('Epoch %05d: saving model to %s' % (epoch, filepath))
                    if self.save_weights_only:
                        self.model.layers[-(num_outputs+1)].save_weights(filepath, overwrite=True)
                    else:
                        self.model.layers[-(num_outputs+1)].save(filepath, overwrite=True)


class TimeHistory(keras.callbacks.Callback):
    """
    taken from https://stackoverflow.com/questions/43178668/record-the-computation-time-for-each-epoch-in-keras-during-model-fit

    usecase:
    time_callback = TimeHistory()
    model.fit(..., callbacks=[..., time_callback],...)
    times = time_callback.times
    """
    def on_train_begin(self, logs={}):
        self.times = []

    def on_epoch_begin(self, batch, logs={}):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, batch, logs={}):
        self.times.append(time.time() - self.epoch_time_start)


##################################################################################################
# helper functions
##################################################################################################

def _generate_predictions(model, data_generator, batch_size, nb_samples, vol_params):
    # whole volumes
    if vol_params is not None:
        for _ in range(nb_samples):  # assumes nr volume
            vols = ne.utils.seg.predict_volumes(model,
                                             data_generator,
                                             batch_size,
                                             vol_params["patch_size"],
                                             vol_params["patch_stride"],
                                             vol_params["grid_size"])
            vol_true, vol_pred = vols[0], vols[1]
            yield (vol_true, vol_pred)

    # just one batch
    else:
        for _ in range(nb_samples):  # assumes nr batches
            vol_pred, vol_true = ne.utils.seg.next_label(model, data_generator)
            yield (vol_true, vol_pred)

